# Literal-listener score for each threshold in `theta`, given probability `phi`.
#
# theta: numeric vector of thresholds; each is mapped to a bin on a 20-step
#        grid via round(theta * 20).
# phi:   event probability; repeated once per element of `theta`.
#
# Returns 1 / (20 - bin) wherever phi > 0 and 0 elsewhere. A threshold whose
# bin is 20 (theta = 1) divides by zero and therefore yields Inf.
lit_list = function(theta, phi) {
  bin = round(theta * 20)
  phi_per_theta = rep(phi, length(theta))
  ifelse(phi_per_theta > 0, 1 / (20 - bin), 0)
}

# Exact pragmatic-speaker probability of choosing utterance 1.
#
# theta1, theta2: scalar thresholds of utterance 1 and utterance 2.
# phi:            scalar event probability.
#
# An utterance scores lit_list(theta, phi) when its threshold is strictly
# below phi and 0 otherwise. The result is utterance 1's score divided by the
# sum of both scores plus a fixed 0.001 "other" term, which keeps the
# denominator positive when neither utterance applies.
pragmatic_speaker = function(theta1, theta2, phi) {

  # Score of a single utterance with threshold `theta`.
  utterance_score = function(theta) {
    if (theta < phi) lit_list(theta, phi) else 0
  }

  score1 = utterance_score(theta1)
  score2 = utterance_score(theta2)

  score1 / (0.001 + score1 + score2)
}


# Unnormalised uniform weights: one weight of 1 per point of the 21-point grid.
uniform_dist = rep(1, times = 21)

# Approximate pragmatic-speaker probability of choosing utterance 1.
#
# theta1, theta2: scalar thresholds of utterance 1 and utterance 2.
# phi:            scalar event probability.
# theta2_dist:    accepted for signature compatibility; not read by this body.
#
# Utterance 1 scores 1 / (21 - round(theta1 * 20)) when theta1 < phi, else 0.
# Utterance 2 scores, when theta2 < phi, the plain mean of lit_list() over a
# 0.05-step grid of thresholds, else 0. The result is utterance 1's score over
# the sum of both scores plus a fixed 0.001 "other" term.
#
# NOTE(review): this uses 21 - bin, whereas lit_list() and
# crude_approx_pragmatic_speaker() use 20 - bin -- confirm which is intended.
# NOTE(review): for phi >= 1 the first grid reaches 1, where lit_list()
# evaluates 1 / (20 - 20) -- confirm this case is never needed.
approx_pragmatic_speaker = function(theta1, theta2, phi, theta2_dist = uniform_dist) {

  score1 = 0
  if (theta1 < phi) {
    score1 = 1 / (21 - round(theta1 * 20))
  }

  score2 = 0
  if (theta2 < phi) {
    # Both branches average lit_list() over a grid; only the grid differs.
    grid = if (theta2 < 1.1) {
      seq(0, min(1, phi), 0.05)
    } else {
      seq(0.9, phi, 0.05)
    }
    score2 = mean(lit_list(grid, phi))
  }

  score1 / (0.001 + score1 + score2)
}

# Crude approximation of the pragmatic-speaker probability of utterance 1.
#
# theta1:      scalar threshold of utterance 1.
# theta2:      accepted for signature compatibility; not read by this body.
# phi:         scalar event probability.
# theta2_dist: weights over the 21 grid thresholds of utterance 2.
#
# Utterance 1 scores 1 / (20 - round(theta1 * 20)) when theta1 < phi (and
# phi > 0), else 0. Utterance 2 is not scored at a single threshold. Its score
# is the theta2_dist-weighted mean of lit_list() over the grid thresholds
# 0, 0.05, ... up to phi - 0.01, with thresholds beyond that contributing 0.
# The result is utterance 1's score over the sum of both scores plus a fixed
# 0.001 "other" term.
crude_approx_pragmatic_speaker = function(theta1, theta2, phi, theta2_dist = uniform_dist) {

  score1 = 0
  if (theta1 < phi) {
    bin1 = round(theta1 * 20)
    score1 = ifelse(phi > 0, 1 / (20 - bin1), 0)
  }

  score2 = 0
  if (phi > 0) {
    below_phi = seq(0, phi - 0.01, 0.05)
    scores = lit_list(below_phi, phi)

    # Pad with zeros so the scores line up with the 21 weights.
    n_missing = 21 - length(scores)
    if (n_missing > 0) {
      scores = c(scores, rep.int(0, n_missing))
    }
    score2 = weighted.mean(scores, w = theta2_dist)
  }

  score1 / (0.001 + score1 + score2)
}
  # Compare the exact speaker with the crude approximation on the full
  # (phi, theta1, theta2) grid: 21 values per variable, 9261 rows, with theta2
  # varying fastest, then theta1, then phi.
  d = data.frame(
    phi = rep(seq(0, 1, 0.05), each = 441),
    theta1 = rep(seq(0, 1, 0.05), times = 21, each = 21),
    theta2 = rep(seq(0, 1, 0.05), times = 441)
  )

d$type = "exact"

# The speaker functions take scalars, so they are applied row by row.
# ungroup() then drops the rowwise grouping so that the later group_by() does
# not have to strip it (which used to raise a warning).
d = d %>%
  rowwise() %>%
  mutate(speaker = pragmatic_speaker(theta1, theta2, phi)) %>%
  ungroup()

d2 = d %>%
  rowwise() %>%
  mutate(speaker = crude_approx_pragmatic_speaker(theta1, theta2, phi), type = "crude_approx") %>%
  ungroup()

# d3 was previously recomputed from an identical grid and is therefore the
# same table as d2; the name is kept as an alias instead of repeating the work.
d3 = d2

d = rbind(d, d2)

# One facet per (theta1, phi) pair: speaker probability as a function of theta2.
d %>%
  ggplot(aes(x = theta2, y = speaker, col = type)) +
  geom_line() +
  geom_point() +
  facet_wrap(~theta1 + phi, ncol = 21)

# Unweighted mean over theta1 and theta2 for each phi, one line per model.
d %>%
  group_by(type, phi) %>%
  summarize(speaker_m = mean(speaker)) %>%
  ggplot(aes(x = phi, y = speaker_m, col = type)) +
  geom_line() +
  geom_point()

# Discretise a Beta(a, b) distribution onto 21 bins of width 0.05.
#
# a, b: shape parameters passed to pbeta().
#
# Bin i covers [0.05 * (i - 1), 0.05 * i] for i = 1..21, and its mass is the
# difference of the Beta CDF at the two edges. The last bin spans [1, 1.05],
# which lies outside the support of the distribution.
discrete_beta = function(a, b) {
  lower_edges = seq(0, 1, 0.05)
  upper_edges = seq(0.05, 1.05, 0.05)
  pbeta(upper_edges, a, b) - pbeta(lower_edges, a, b)
}
# Plot the expected pragmatic-speaker probability as a function of phi for the
# exact model and for the crude approximation.
#
# theta1_dist: weights over the 21 grid values of theta1.
# theta2_dist: weights over the 21 grid values of theta2.
# label:       facet strip label for the panel.
#
# Both models are evaluated on the full (phi, theta1, theta2) grid (21 values
# each, step 0.05) and then averaged: first over theta2 within each
# (phi, theta1) cell, then over theta1 within each phi.
#
# Returns a ggplot object with one line per model.
plot_model_comparison = function(theta1_dist, theta2_dist, label = "uniform") {

  # theta2 varies fastest, then theta1, then phi. Each (phi, theta1) group
  # therefore holds its 21 theta2 rows in grid order, matching the weights.
  d = data.frame(
    phi = rep(seq(0, 1, 0.05), each = 441),
    theta1 = rep(seq(0, 1, 0.05), times = 21, each = 21),
    theta2 = rep(seq(0, 1, 0.05), times = 441)
  )
  d$type = "Exact solution"

  # Two-stage weighted average of `speaker`: over theta2 with `inner_weights`,
  # then over theta1 with `theta1_dist`. The rowwise grouping is dropped first
  # so group_by() starts from a plain table.
  expected_speaker = function(scored, inner_weights) {
    scored %>%
      ungroup() %>%
      group_by(type, phi, theta1) %>%
      summarize(speaker_m2 = weighted.mean(speaker, w = inner_weights)) %>%
      group_by(type, phi) %>%
      summarize(speaker_m = weighted.mean(speaker_m2, w = theta1_dist)) %>%
      ungroup()
  }

  # The speaker functions take scalars, so they are applied row by row.
  d_exact = d %>%
    rowwise() %>%
    mutate(speaker = pragmatic_speaker(theta1, theta2, phi))
  d_exact = expected_speaker(d_exact, theta2_dist)

  # The approximation already weights by theta2_dist internally and does not
  # read theta2, so its theta2 rows are averaged with uniform weights.
  d_approx = d %>%
    rowwise() %>%
    mutate(speaker = crude_approx_pragmatic_speaker(theta1, theta2, phi, theta2_dist = theta2_dist), type = "Approximation")
  d_approx = expected_speaker(d_approx, uniform_dist)

  d = rbind(d_exact, d_approx)
  d$label = label

  # Axis titles are blanked in the theme; the xlab()/ylab() text is kept as in
  # the original but is not drawn.
  p = d %>%
    ggplot(aes(x = phi, y = speaker_m, col = type)) +
    geom_line() +
    geom_point() +
    ylim(0, 1) +
    xlab("event probability") +
    ylab("expected pragmatic speaker") +
    facet_wrap(~label) +
    theme(legend.text = element_text(size = 14),
          strip.text.x = element_text(size = 14),
          axis.title = element_blank(),
          axis.text = element_text(size = 12),
          legend.position = "bottom") +
    guides(col = guide_legend(title = ""))

  return(p)

}

# Discretised Beta distributions used as threshold distributions for
# "might", "probably" and the bare form.
might_dist = discrete_beta(0.93, 3.08)
probably_dist = discrete_beta(2.55, 1.77)
bare_dist = discrete_beta(1.27, 0.29)

# Eight panels, all without their own legend. Each titled panel (p1, p3, p5,
# p7) goes in the top row of the grid; the panel below it (p2, p4, p6, p8)
# swaps the two distributions, except for the uniform pair, which is repeated.

# uniform / uniform
p1 = plot_model_comparison(uniform_dist, uniform_dist) +
  theme(legend.position = "none") +
  ggtitle("uniform-uniform")
p2 = plot_model_comparison(uniform_dist, uniform_dist) +
  theme(legend.position = "none")

# might / bare
p3 = plot_model_comparison(might_dist, bare_dist, "might") +
  theme(legend.position = "none") +
  ggtitle("might-bare")
p4 = plot_model_comparison(bare_dist, might_dist, "bare") +
  theme(legend.position = "none")

# might / probably
p5 = plot_model_comparison(might_dist, probably_dist, "might") +
  theme(legend.position = "none") +
  ggtitle("might-probably")
p6 = plot_model_comparison(probably_dist, might_dist, "probably") +
  theme(legend.position = "none")

# probably / bare
p7 = plot_model_comparison(probably_dist, bare_dist, "probably") +
  theme(legend.position = "none") +
  ggtitle("probably-bare")
p8 = plot_model_comparison(bare_dist, probably_dist, "bare") +
  theme(legend.position = "none")

# 2 x 4 grid with shared axis labels: titled panels on top, their
# counterparts underneath.
sim_plots = arrangeGrob(
  p1, p3, p5, p7,
  p2, p4, p6, p8,
  ncol = 4,
  left = "Expected pragmatic speaker",
  bottom = "Event probability"
)

# Extract the legend grob from a ggplot object.
#
# a.gplot: a ggplot object that draws a legend.
#
# The plot is built and converted to a gtable, and the grob named "guide-box"
# is returned. If several grobs carry that name, the first one is returned.
# Stops with an explicit error when no such grob exists, e.g. for a plot
# whose legend has been switched off.
extract_legend = function (a.gplot) {
  tmp <- ggplot_gtable(ggplot_build(a.gplot))

  # vapply keeps the result a character vector whatever the grob list holds;
  # grobs without a name become NA and can never match.
  grob_names <- vapply(
    tmp$grobs,
    function(x) if (is.null(x$name)) NA_character_ else as.character(x$name)[1],
    character(1)
  )

  leg <- which(grob_names == "guide-box")
  if (length(leg) == 0) {
    stop("no \"guide-box\" grob found: the plot does not draw a legend",
         call. = FALSE)
  }

  tmp$grobs[[leg[1]]]
}

# The panels in sim_plots carry no legend, so one is taken from a fresh
# uniform/uniform plot that still has its legend.
legend = extract_legend(
  plot_model_comparison(theta1_dist = uniform_dist, theta2_dist = uniform_dist)
)

# Stack the panel grid above the shared legend at a 7:1 height ratio.
p = grid.arrange(sim_plots, legend, heights = c(7, 1))

ggsave(
  filename = "../../papers/cognition/plots/approx-simulations.pdf",
  plot = p,
  width = 30,
  height = 12,
  units = "cm"
)